{
 "metadata": {
  "name": "",
  "signature": "sha256:86c450cd21f30715fb88c7b8985328f982689c1999d4e1b1db439e5bf28f1117"
 },
 "nbformat": 3,
 "nbformat_minor": 0,
 "worksheets": [
  {
   "cells": [
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "import torch\n",
      "from torch.utils import data\n",
      "import numpy as np\n",
      "import pandas as pd\n",
      "import cv2"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 38
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "class FaceDataset(data.Dataset):\n",
      "    \"\"\"Dataset of face images indexed by the file dataset.csv under root.\n",
      "\n",
      "    The CSV is read without a header row: column 0 holds image paths\n",
      "    relative to root and column 1 holds the label of each image.\n",
      "    Each sample is a float tensor of shape IMAGE_SHAPE scaled to [0, 1],\n",
      "    returned together with its label.\n",
      "    \"\"\"\n",
      "\n",
      "    # (channels, height, width) every equalized image is reshaped to.\n",
      "    IMAGE_SHAPE = (1, 48, 48)\n",
      "\n",
      "    def __init__(self, root):\n",
      "        \"\"\"Read the image paths and labels listed in dataset.csv under root.\n",
      "\n",
      "        Args:\n",
      "            root: Directory containing dataset.csv; the image paths in the\n",
      "                CSV are resolved relative to this directory.\n",
      "        \"\"\"\n",
      "        super(FaceDataset, self).__init__()\n",
      "        self.root = root\n",
      "        # Parse the index file once instead of once per column.\n",
      "        table = pd.read_csv(self.root + \"/dataset.csv\", header=None, usecols=[0, 1])\n",
      "        self.pic = np.array(table.iloc[:, 0])\n",
      "        self.label = np.array(table.iloc[:, 1])\n",
      "\n",
      "    def __getitem__(self, index):\n",
      "        \"\"\"Return (face_tensor, label) for the sample at index.\n",
      "\n",
      "        Raises:\n",
      "            IOError: If the image file is missing or cannot be decoded.\n",
      "        \"\"\"\n",
      "        image_path = self.root + \"/\" + self.pic[index]\n",
      "        face = cv2.imread(image_path)\n",
      "        # cv2.imread reports failure by returning None instead of raising,\n",
      "        # so fail here with the offending path rather than inside cvtColor.\n",
      "        if face is None:\n",
      "            raise IOError(\"could not read image: %s\" % image_path)\n",
      "        face_gray = cv2.cvtColor(face, cv2.COLOR_BGR2GRAY)\n",
      "        # Equalize the intensity histogram before scaling to [0, 1].\n",
      "        face_hist = cv2.equalizeHist(face_gray)\n",
      "        face_normalized = face_hist.reshape(self.IMAGE_SHAPE) / 255.0\n",
      "        # The division yields float64; convert to a float32 tensor.\n",
      "        face_tensor = torch.from_numpy(face_normalized).float()\n",
      "        label = self.label[index]\n",
      "        return face_tensor, label\n",
      "\n",
      "    def __len__(self):\n",
      "        \"\"\"Return the number of samples listed in the CSV.\"\"\"\n",
      "        return self.pic.shape[0]"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 39
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# Root directory of the training split; it must contain dataset.csv.\n",
      "path = r\"/media/root/515e7d3a-49ac-40be-ba58-fef9702d123c/work_record/\u4f5c\u4e1a3/face/train\"\n",
      "# Bind the dataset to its own name: calling it `data` would shadow the\n",
      "# imported torch.utils.data module and break re-running the class cell.\n",
      "train_set = FaceDataset(path)\n",
      "len(train_set)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "metadata": {},
       "output_type": "pyout",
       "prompt_number": 40,
       "text": [
        "24000"
       ]
      }
     ],
     "prompt_number": 40
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# Load the first sample explicitly so this cell does not rely on a\n",
      "# `face` variable left over from an earlier kernel session.\n",
      "face, label = FaceDataset(path)[0]\n",
      "face.shape"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "metadata": {},
       "output_type": "pyout",
       "prompt_number": 29,
       "text": [
        "torch.Size([1, 48, 48])"
       ]
      }
     ],
     "prompt_number": 29
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [],
     "language": "python",
     "metadata": {},
     "outputs": []
    }
   ],
   "metadata": {}
  }
 ]
}